r"""
Methods for evaluating performance.
"""
from typing import Any, Dict, Optional

import torch

from botorch.acquisition.analytic import PosteriorMean
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.models.model import Model
from botorch.optim.optimize import optimize_acqf

from botorch.test_functions.base import (
    BaseTestProblem,
    ConstrainedBaseTestProblem,
    MultiObjectiveTestProblem,
)
from botorch.utils.multi_objective.box_decompositions.dominated import (
    DominatedPartitioning,
)
from .experiment_utils import (
    eval_problem,
    get_best_feasible_f,
)
from torch import Tensor


def evaluate_performance(
    base_function: BaseTestProblem,
    Y_true: Optional[Tensor] = None,
    model: Optional[Model] = None,
    posterior_transform: Optional[PosteriorTransform] = None,
    mc_objective: Optional[MCAcquisitionObjective] = None,
    bounds: Optional[Tensor] = None,
    **optimization_kwargs: Any,
) -> Dict[str, Optional[Tensor]]:
    """Evaluate in-sample performance.

    For MOO, this computes the in-sample HV. For single objective
    optimization, this computes the best observed objective and, for
    unconstrained problems with a model available, the best inferred
    objective.

    Args:
        base_function: A test problem.
        Y_true: The observed true values (this is an omniscient evaluation
            if there is observation noise). For constrained problems the
            first entry along the last dimension is used as the objective
            and the remaining entries as the constraints. If `None`, every
            metric that is derived from observations is reported as `None`.
        model: A fitted model. Only used for unconstrained single-objective
            problems; if `None`, `best_inferred_obj` is not reported.
        posterior_transform: An optional posterior transform, forwarded to
            `compute_best_inferred_objective`.
        mc_objective: Accepted for interface compatibility; currently unused.
        bounds: Optional bounds on which to optimize the model, forwarded to
            `compute_best_inferred_objective`.
        **optimization_kwargs: Additional keyword arguments forwarded to
            `compute_best_inferred_objective`.

    Returns:
        A dictionary mapping metric names to values:
            - MOO: `in_sample_hv` and `pareto_Y`.
            - Constrained single-objective: `best_obj`.
            - Unconstrained single-objective: `best_obj` and, when `model`
              is given, `best_inferred_obj`.
    """
    if isinstance(base_function, MultiObjectiveTestProblem):
        # Without observations there is no partitioning to build, so report
        # both metrics as missing instead of touching an unbound `bd`.
        if Y_true is None:
            return {"in_sample_hv": None, "pareto_Y": None}
        # Compute in-sample HV
        bd = DominatedPartitioning(ref_point=base_function.ref_point, Y=Y_true)
        return {
            "in_sample_hv": bd.compute_hypervolume().view(-1).cpu(),
            "pareto_Y": bd.pareto_Y,
        }  # IDEA: inferred hypervolume, but maybe not necessary b/c no KG, ES, etc.

    # single-objective
    performance_dict: Dict[str, Optional[Tensor]] = {}
    if isinstance(base_function, ConstrainedBaseTestProblem):
        if Y_true is None:
            performance_dict["best_obj"] = None
        else:
            performance_dict["best_obj"] = get_best_feasible_f(
                obj=Y_true[..., [0]],
                cons=Y_true[..., 1:],
                allow_inf=True,  # need to know if it isn't feasible.
            ).cpu()
        # TODO: potentially add slack (constraint satisfaction)
    else:
        performance_dict["best_obj"] = (
            None if Y_true is None else Y_true.max().cpu()
        )
        # The inferred objective requires optimizing the model's posterior
        # mean, so it can only be computed when a model was supplied.
        if model is not None:
            performance_dict["best_inferred_obj"] = compute_best_inferred_objective(
                base_function=base_function,
                model=model,
                posterior_transform=posterior_transform,
                bounds=bounds,
                **optimization_kwargs,
            )
    return performance_dict


def compute_best_inferred_objective(
    base_function: BaseTestProblem,
    model: Model,
    posterior_transform: Optional[PosteriorTransform] = None,
    bounds: Optional[Tensor] = None,
    **optimization_kwargs,
) -> Tensor:
    """Evaluate the test problem at the maximizer of the model's posterior mean.

    A `PosteriorMean` acquisition function is built from the model and
    maximized with `optimize_acqf` (`q=1`); the resulting candidate is then
    passed to `eval_problem`.

    Args:
        base_function: A test problem.
        model: A fitted model.
        posterior_transform: An optional posterior transform.
        bounds: A `2 x d` tensor of lower and upper bounds, on which to optimize
            *the model*, defaults to [0, 1]^d.
        **optimization_kwargs: Additional keyword arguments to pass to
            `optimize_acqf`. `num_restarts` and `raw_samples` default to 16
            and 1024 respectively when not provided.

    Returns:
        The first element of `eval_problem`'s output at the optimizer of the
        posterior mean (presumably the true, noiseless objective value --
        confirm against `eval_problem`).
    """
    mean_acqf = PosteriorMean(
        model=model, posterior_transform=posterior_transform, maximize=True
    )

    if bounds is None:
        # Fall back to the unit hypercube, allocated on the model's device.
        unit_cube = torch.zeros(2, base_function.dim, device=model.device)
        unit_cube[1] = 1.0
        bounds = unit_cube

    # Fill in optimizer defaults only where the caller did not specify them.
    optimization_kwargs.setdefault("num_restarts", 16)
    optimization_kwargs.setdefault("raw_samples", 1024)

    # Release cached GPU memory before running the optimizer.
    torch.cuda.empty_cache()
    best_x, _ = optimize_acqf(
        acq_function=mean_acqf,
        bounds=bounds,
        q=1,
        **optimization_kwargs,
    )
    problem_output = eval_problem(base_function=base_function, X=best_x)
    return problem_output[0]


def update_all_performance_summary(
    Y_true: Tensor,
    base_function: BaseTestProblem,
    all_performance_summary: Dict[str, list],
    model: Optional[Model] = None,
    posterior_transform: Optional[PosteriorTransform] = None,
    bounds: Optional[Tensor] = None,
    **optimization_kwargs: Any,
) -> None:
    """Update performance summary in-place.

    Calls `evaluate_performance` and appends each returned metric value to
    the list stored under the same key in `all_performance_summary`.

    Args:
        Y_true: The observed true values (this is an omniscient evaluation
            if there is observation noise).
        base_function: A test problem.
        all_performance_summary: A dictionary mapping performance metric to a
            list containing the value of the metric at each previous
            evaluation. Metrics that are not yet present are added with a
            fresh list.
        model: An optional fitted model, forwarded to `evaluate_performance`.
        posterior_transform: An optional posterior transform, forwarded to
            `evaluate_performance`.
        bounds: Optional bounds, forwarded to `evaluate_performance`.
        **optimization_kwargs: Additional keyword arguments forwarded to
            `evaluate_performance`.
    """
    performance_summary = evaluate_performance(
        Y_true=Y_true,
        base_function=base_function,
        model=model,
        posterior_transform=posterior_transform,
        bounds=bounds,
        **optimization_kwargs,
    )
    for metric, value in performance_summary.items():
        # `setdefault` avoids a KeyError the first time a metric is reported
        # while behaving identically for pre-seeded dicts and defaultdicts.
        all_performance_summary.setdefault(metric, []).append(value)
